{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b0ee5aff-46d6-47c5-8306-901cfd9206d9",
   "metadata": {},
   "source": [
    "# Fine-tuning DistilBERT for question answering\n",
    "\n",
    "This guide describes how to fine-tune the DistilBERT model on the Stanford Question Answering Dataset (SQuAD) for question answering using Kubeflow Trainer.\n",
    "\n",
    "This guide is adapted from the Hugging Face question answering task recipe: https://huggingface.co/docs/transformers/en/tasks/question_answering\n",
    "\n",
    "Pretrained DistilBERT: https://huggingface.co/docs/transformers/en/model_doc/distilbert\n",
    "\n",
    "SQuAD dataset: https://huggingface.co/datasets/rajpurkar/squad"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c31bc8f2",
   "metadata": {},
   "source": [
    "# Install the Kubeflow SDK"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4504947e33b6021a",
   "metadata": {},
   "source": [
    "You need to install the Kubeflow SDK to interact with Kubeflow Trainer APIs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5e86ea307b3eec9",
   "metadata": {},
   "outputs": [],
   "source": "# !pip install -U kubeflow"
  },
  {
   "cell_type": "markdown",
   "id": "21c20e6ec87dcec2",
   "metadata": {},
   "source": [
    "Install the dependencies that this notebook uses:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "10606685",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install \"cloudpathlib[all]\" \"transformers[torch]\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a15b91f",
   "metadata": {},
   "source": [
    "# Define the Hugging Face training script\n",
    "\n",
    "Wrap the training script in a single function so that Kubeflow Trainer can run it inside the TrainJob. The function must be self-contained, so all imports go inside the function body."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e7f396-32ce-4d23-b76f-9684de470471",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_distilbert(model_name: str, bucket: str = None):\n",
    "    import os\n",
    "\n",
    "    from cloudpathlib import CloudPath\n",
    "    from datasets import load_dataset\n",
    "    import torch\n",
    "    from transformers import AutoTokenizer, DefaultDataCollator, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "\n",
    "    import torch.distributed as dist\n",
    "\n",
    "    # Initialize distributed environment\n",
    "    _, backend = (\"cuda\", \"nccl\") if torch.cuda.is_available() else (\"cpu\", \"gloo\")\n",
    "    dist.init_process_group(backend=backend)\n",
    "\n",
    "    local_rank = int(os.getenv(\"LOCAL_RANK\", 0))\n",
    "    print(\n",
    "        \"Distributed Training with WORLD_SIZE: {}, RANK: {}, LOCAL_RANK: {}.\".format(\n",
    "            dist.get_world_size(),\n",
    "            dist.get_rank(),\n",
    "            local_rank,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # Download the dataset and tokenizer\n",
    "    squad = load_dataset(\"squad\", split=\"train[:100]\")    \n",
    "\n",
    "    squad = squad.train_test_split(test_size=0.2, shuffle=False)\n",
    "    \n",
    "    tokenizer = AutoTokenizer.from_pretrained(f'distilbert/{model_name}')\n",
    "    \n",
    "    # Define the preprocessing function\n",
    "    def preprocess_function(examples):\n",
    "        questions = [q.strip() for q in examples[\"question\"]]\n",
    "        inputs = tokenizer(\n",
    "            questions,\n",
    "            examples[\"context\"],\n",
    "            max_length=384,\n",
    "            truncation=\"only_second\",\n",
    "            return_offsets_mapping=True,\n",
    "            padding=\"max_length\",\n",
    "        )\n",
    "    \n",
    "        offset_mapping = inputs.pop(\"offset_mapping\")\n",
    "        answers = examples[\"answers\"]\n",
    "        start_positions = []\n",
    "        end_positions = []\n",
    "    \n",
    "        for i, offset in enumerate(offset_mapping):\n",
    "            answer = answers[i]\n",
    "            start_char = answer[\"answer_start\"][0]\n",
    "            end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
    "            sequence_ids = inputs.sequence_ids(i)\n",
    "    \n",
    "            # Find the start and end of the context\n",
    "            idx = 0\n",
    "            while sequence_ids[idx] != 1:\n",
    "                idx += 1\n",
    "            context_start = idx\n",
    "            while sequence_ids[idx] == 1:\n",
    "                idx += 1\n",
    "            context_end = idx - 1\n",
    "    \n",
    "            # If the answer is not fully inside the context, label it (0, 0)\n",
    "            if offset[context_start][0] > end_char or offset[context_end][1] < start_char:\n",
    "                start_positions.append(0)\n",
    "                end_positions.append(0)\n",
    "            else:\n",
    "                # Otherwise it's the start and end token positions\n",
    "                idx = context_start\n",
    "                while idx <= context_end and offset[idx][0] <= start_char:\n",
    "                    idx += 1\n",
    "                start_positions.append(idx - 1)\n",
    "    \n",
    "                idx = context_end\n",
    "                while idx >= context_start and offset[idx][1] >= end_char:\n",
    "                    idx -= 1\n",
    "                end_positions.append(idx + 1)\n",
    "    \n",
    "        inputs[\"start_positions\"] = start_positions\n",
    "        inputs[\"end_positions\"] = end_positions\n",
    "        return inputs\n",
    "        \n",
    "    # Apply the preprocessing function to the dataset\n",
    "    tokenized_squad = squad.map(preprocess_function, batched=True, remove_columns=squad[\"train\"].column_names)\n",
    "        \n",
    "    # Create a batch of examples using DefaultDataCollator\n",
    "    data_collator = DefaultDataCollator()\n",
    "\n",
    "    # Load the model\n",
    "    model = AutoModelForQuestionAnswering.from_pretrained(f'distilbert/{model_name}')\n",
    "\n",
    "    # Define training hyperparameters\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=model_name,\n",
    "        eval_strategy=\"epoch\",\n",
    "        learning_rate=2e-5,\n",
    "        per_device_train_batch_size=1,\n",
    "        per_device_eval_batch_size=1,\n",
    "        num_train_epochs=1,\n",
    "        weight_decay=0.01,\n",
    "        push_to_hub=False,\n",
    "    )\n",
    "    \n",
    "    # Prepare trainer with configuration\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        train_dataset=tokenized_squad[\"train\"],\n",
    "        eval_dataset=tokenized_squad[\"test\"],\n",
    "        processing_class=tokenizer,\n",
    "        data_collator=data_collator,\n",
    "    )\n",
    "    \n",
    "    trainer.train()\n",
    "\n",
    "    # Upload the fine-tuned model\n",
    "    if bucket:\n",
    "        (CloudPath(bucket) / model_name).upload_from(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bf5ab9ba-6054-40d6-839d-f84ff0fba8fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Runtime mpi-distributed must have trainer.kubeflow.org/framework label.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Name: deepspeed-distributed, Framework: deepspeed, Trainer Type: CustomTrainer\n",
      "Name: mlx-distributed, Framework: mlx, Trainer Type: CustomTrainer\n",
      "Name: torch-distributed, Framework: torch, Trainer Type: CustomTrainer\n",
      "Name: torchtune-llama3.2-1b, Framework: torchtune, Trainer Type: BuiltinTrainer\n",
      "Name: torchtune-llama3.2-3b, Framework: torchtune, Trainer Type: BuiltinTrainer\n"
     ]
    }
   ],
   "source": [
    "from kubeflow.trainer import TrainerClient, CustomTrainer\n",
    "\n",
    "for r in TrainerClient().list_runtimes():\n",
    "    print(f\"Name: {r.name}, Framework: {r.trainer.framework}, Trainer Type: {r.trainer.trainer_type.value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3fd0c5f-f359-4c6c-9f0e-2e91904579b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# To upload to object storage (S3, GCS or Azure Blob Storage), set the bucket with protocol, e.g., \"s3://my-bucket/folder\"\n",
    "BUCKET = None\n",
    "\n",
    "MODEL_NAME = \"distilbert-base-uncased\"\n",
    "args = {\n",
    "    \"model_name\": MODEL_NAME,\n",
    "    \"bucket\": BUCKET,\n",
    "}\n",
    "\n",
    "job_id = TrainerClient().train(\n",
    "    trainer=CustomTrainer(\n",
    "        func=train_distilbert,\n",
    "        func_args=args,\n",
    "        num_nodes=1,\n",
    "        packages_to_install=[\"datasets\", \"transformers[torch]\", \"cloudpathlib[all]\"],\n",
    "        resources_per_node={\n",
    "            \"cpu\": \"2\",\n",
    "            \"memory\": \"12Gi\",\n",
    "            # Uncomment this to distribute the TrainJob using GPU nodes\n",
    "            # \"nvidia.com/gpu\": 1,\n",
    "        },\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aabade1b-2c0b-492b-be97-03b4f0e037f0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'rafd89de924b'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train API generates a random TrainJob id.\n",
    "job_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9294ea5",
   "metadata": {},
   "source": [
    "# Check the TrainJob details\n",
    "\n",
    "Use the `list_jobs()` and `get_job()` APIs to get details about the created TrainJob and its steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fc5de9e8-f798-4cfd-bc6e-f17774cbd235",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TrainJob: rafd89de924b, Status: Unknown, Created at: 2025-04-29 01:22:14+00:00\n"
     ]
    }
   ],
   "source": [
    "for job in TrainerClient().list_jobs():\n",
    "    print(f\"TrainJob: {job.name}, Status: {job.status}, Created at: {job.creation_timestamp}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9f8ef95-b309-4987-9abb-760dc9c1e050",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wait for the running status.\n",
    "TrainerClient().wait_for_job_status(name=job_id, status={\"Running\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a3eec801",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step: node-0, Status: Running, Devices: cpu x 2\n"
     ]
    }
   ],
   "source": [
    "# Show the status and devices of each TrainJob step (one step per training node).\n",
    "for c in TrainerClient().get_job(name=job_id).steps:\n",
    "    print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30f812d7",
   "metadata": {},
   "source": [
    "# Show the TrainJob logs\n",
    "\n",
    "Use the `get_job_logs()` API to retrieve the TrainJob logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d31f102f-8583-42c4-a6f7-f5a5eb0e7f98",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[node-0]: WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\n",
      "[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] \n",
      "[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] *****************************************\n",
      "[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n",
      "[node-0]: W0429 01:22:31.907000 1 site-packages/torch/distributed/run.py:793] *****************************************\n",
      "[node-0]: Distributed Training with WORLD_SIZE: 2, RANK: 1, LOCAL_RANK: 1.\n",
      "[node-0]: Distributed Training with WORLD_SIZE: 2, RANK: 0, LOCAL_RANK: 0.\n",
      "Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 627416.25 examples/s]\n",
      "Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 1041555.11 examples/s]\n",
      "Map: 100%|██████████| 80/80 [00:00<00:00, 1361.10 examples/s]\n",
      "Map: 100%|██████████| 20/20 [00:00<00:00, 2123.27 examples/s]\n",
      "Map: 100%|██████████| 80/80 [00:00<00:00, 1030.86 examples/s]\n",
      "Map: 100%|██████████| 20/20 [00:00<00:00, 1877.53 examples/s]\n",
      "[node-0]: Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`\n",
      "[node-0]: Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n",
      "[node-0]: You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "[node-0]: Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n",
      "[node-0]: You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "  0%|          | 0/40 [00:00<?, ?it/s][rank0]:[W429 01:22:58.895439547 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n",
      "[node-0]: [rank1]:[W429 01:22:58.895689005 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n",
      "100%|██████████| 40/40 [02:36<00:00,  4.10s/it]\n",
      "  0%|          | 0/10 [00:00<?, ?it/s]\u001B[A\n",
      " 20%|██        | 2/10 [00:00<00:03,  2.54it/s]\u001B[A\n",
      " 30%|███       | 3/10 [00:01<00:03,  1.80it/s]\u001B[A\n",
      " 40%|████      | 4/10 [00:02<00:03,  1.55it/s]\u001B[A\n",
      " 50%|█████     | 5/10 [00:03<00:03,  1.43it/s]\u001B[A\n",
      " 60%|██████    | 6/10 [00:03<00:02,  1.37it/s]\u001B[A\n",
      " 70%|███████   | 7/10 [00:04<00:02,  1.32it/s]\u001B[A\n",
      " 80%|████████  | 8/10 [00:05<00:01,  1.30it/s]\u001B[A\n",
      " 90%|█████████ | 9/10 [00:06<00:00,  1.29it/s]\u001B[A\n",
      "                                               \u001B[A\n",
      "[node-0]: {'eval_loss': 5.543211936950684, 'eval_runtime': 7.9713, 'eval_samples_per_second': 2.509, 'eval_steps_per_second': 1.254, 'epoch': 1.0}\n",
      "100%|██████████| 40/40 [02:45<00:00,  4.10s/it]\n",
      "100%|██████████| 10/10 [00:07<00:00,  1.28it/s]\u001B[A\n",
      "[node-0]: {'train_runtime': 165.12, 'train_samples_per_second': 0.484, 'train_steps_per_second': 0.242, 'train_loss': 5.764264678955078, 'epoch': 1.0}\n",
      "100%|██████████| 40/40 [02:45<00:00,  4.13s/it]\u001B[A\n"
     ]
    }
   ],
   "source": [
    "for logline in TrainerClient().get_job_logs(job_id, follow=True):\n",
    "    print(logline)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66e6f04e-6e0c-4402-86d1-c0ccbe3f8602",
   "metadata": {},
   "source": [
    "# Inference\n",
    "\n",
    "Download the model and run inference on some examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1af042a5-abec-456c-a74b-fa78efb8000e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cloudpathlib import CloudPath\n",
    "from transformers import pipeline\n",
    "from transformers.trainer_utils import get_last_checkpoint\n",
    "\n",
    "if BUCKET:\n",
    "    # Download the fine-tuned model from object storage\n",
    "    (CloudPath(BUCKET) / MODEL_NAME).download_to(MODEL_NAME)\n",
    "\n",
    "    question = \"How many programming languages does BLOOM support?\"\n",
    "    context = \"BLOOM has 176 billion parameters and can generate text in 46 natural languages and 13 programming languages.\"\n",
    "\n",
    "    # Use the latest checkpoint-<step> directory if one exists, otherwise the model directory itself\n",
    "    model_path = get_last_checkpoint(MODEL_NAME) or f\"./{MODEL_NAME}\"\n",
    "    question_answerer = pipeline(\"question-answering\", model=model_path)\n",
    "    print(question_answerer(question=question, context=context))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32971b7b",
   "metadata": {},
   "source": [
    "# Clean up\n",
    "\n",
    "To delete the TrainJob you can use the `delete_job()` API and pass the generated `job_id`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "16608b6d-87bf-414d-a174-331f174c6add",
   "metadata": {},
   "outputs": [],
   "source": [
    "# _ = TrainerClient().delete_job(job_id)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
